import numpy as np
import pandas as pd
import cvxpy as cp
import matplotlib.pyplot as plt


def build_network(Ab, Vb, network_data):
    """Build the per-unit network model from a line-data table.

    Parameters
    ----------
    Ab : float
        Base power.
    Vb : float
        Base voltage.
    network_data : str, path or file-like
        Anything accepted by ``np.loadtxt``. One row per line with columns
        ``[from_node, to_node, r, x, b, length, ampacity]``; node numbers are
        1-indexed. ``b`` is scaled by 1e-6 to siemens, so it is presumably
        given in micro-siemens in the file. NOTE(review): r, x, b are assumed
        to be per unit length (they are multiplied by ``length``) — confirm
        against the data files.

    Returns
    -------
    dict
        Per-unit series resistance/reactance ('R', 'X'), line shunt
        susceptance ('B'), nodal shunt susceptance ('Bn'), admittance matrices
        ('YY', 'YL', 'YT', 'YLT', 'YYT'), incidence matrices ('A', 'Ap', 'Am',
        'Adown', 'Aup'), 0-indexed line end nodes ('Voltage_Up',
        'Voltage_Down'), ampacity limits ('Imax', 'Imax_squared'), sizes and
        the (unit-converted) line data.
    """
    Ib = Ab / (np.sqrt(3) * Vb)  # Base current
    Yb = Ab / (Vb ** 2)  # Base admittance

    # ndmin=2 keeps a single-line file as a (1, n_cols) table instead of 1-D
    linedata = np.loadtxt(network_data, ndmin=2)
    Voltage_Up = (linedata[:, 0] - 1).astype(int)  # Sending node of each line (0-indexed)
    Voltage_Down = (linedata[:, 1] - 1).astype(int)  # Receiving node of each line (0-indexed)
    linedata[:, 4] = linedata[:, 4] * 1e-6  # Convert the shunt susceptance to Siemens
    n_lines = len(linedata)
    n_nodes = int(max(np.max(linedata[:, 0]), np.max(linedata[:, 1])))
    print(f"The network consists of {n_nodes} nodes and {n_lines} lines.")

    # Per-unit line parameters
    line_lengths = linedata[:, 5]
    R = (linedata[:, 2] * Yb) * line_lengths
    X = (linedata[:, 3] * Yb) * line_lengths
    B = -(linedata[:, 4] / Yb) * line_lengths  # !!! Assuming capacitive lines!!!

    lines = np.arange(n_lines)

    # Nodal shunt susceptance: each line puts half of its shunt at both ends.
    # np.add.at accumulates repeated indices, so parallel lines all count.
    Bn = np.zeros(n_nodes)
    np.add.at(Bn, Voltage_Up, B / 2)
    np.add.at(Bn, Voltage_Down, B / 2)

    YL = np.diag(1 / (R + 1j * X))  # Series admittance of each line
    YT = np.diag(1j * Bn)  # Shunt admittance at each node (consistent with Bn)
    YLT = np.diag(B) * 1j / 2  # Shunt contribution at the sending end of a line

    # Branch-node incidence matrix: +1 at the sending node, -1 at the receiving node
    A = np.zeros((n_lines, n_nodes))
    A[lines, Voltage_Up] = 1
    A[lines, Voltage_Down] = -1

    # Full nodal admittance matrix
    Y = (A.T @ YL) @ A + YT
    # Off-diagonal (i, j) of YYT is the sending-end shunt of the line i -> j
    ALA = A.T @ YLT @ A
    YYT = -(ALA - np.diag(np.diag(ALA)))

    # Incidence matrices as defined for the SOCP relaxation
    Ap = A.T.copy()  # node x line: +1 sending node, -1 receiving node
    Am = np.zeros((n_nodes, n_lines))
    Am[Voltage_Down, lines] = -1  # only the receiving node is marked

    Adown = Ap.clip(min=0)  # Lines downstream of nodes
    Aup = -Ap.clip(max=0)  # Lines upstream of nodes

    # Ampacity limits in per unit
    Imax = linedata[:, 6] / Ib
    Imax_sq = Imax ** 2

    network = {'A': A, 'YY': Y, 'R': R, 'X': X, 'B': B, 'YL': YL, 'YT': YT, 'YLT': YLT, 'YYT': YYT, 'linedata': linedata,
               'Ap': Ap, 'Am': Am, 'Bn': Bn, 'n_nodes': n_nodes, 'n_lines': n_lines, 'Imax_squared': Imax_sq, 'Imax': Imax,
               'Voltage_Up': Voltage_Up, 'Voltage_Down': Voltage_Down, 'Adown': Adown, 'Aup': Aup}
    return network


def _magnitude_sensitivity(x, K, tol=1e-12):
    """Return the sensitivity of |x| given complex sensitivities K = dx/du.

    Row i is Re(conj(x_i) * K[i, :]) / |x_i|. |x_i| is not differentiable at
    zero, so for |x_i| <= tol the row is |K[i, :]| instead (the rate at which
    |x_i| grows away from zero) rather than the NaN a plain division gives.
    """
    x = np.asarray(x)
    mag = np.abs(x)
    nonzero = mag > tol
    safe_mag = np.where(nonzero, mag, 1.0)  # avoid 0/0 in the masked-out rows
    smooth = np.real(np.conj(x)[:, None] * K) / safe_mag[:, None]
    return np.where(nonzero[:, None], smooth, np.abs(K))


def sensitivity_coefficients(network, E, S):
    """Compute voltage, current and slack-power sensitivity coefficients.

    Sensitivities are taken with respect to the active (P) and reactive (Q)
    power injections at the non-slack nodes, linearised around the operating
    point ``E``. Node 0 is the slack node: its voltage is fixed, so it is
    excluded from the unknowns.

    Parameters
    ----------
    network : dict
        Network description with the keys 'A', 'YY', 'YL', 'YLT', 'linedata',
        'R', 'X' and 'Bn' (as produced by ``build_network``).
    E : (n_nodes,) complex ndarray
        Nodal voltages of the operating point, slack node first.
    S : (n_nodes,) complex ndarray
        Nodal power injections. Currently unused; kept for API compatibility.

    Returns
    -------
    K_V_P_mag, K_V_Q_mag : (n_nodes-1, n_nodes-1) ndarray
        d|E_i|/dP_l and d|E_i|/dQ_l for non-slack nodes i and l.
    K_I_P_mag, K_I_Q_mag : (n_lines, n_nodes-1) ndarray
        d|I_k|/dP_l and d|I_k|/dQ_l of the sending-end line current
        (series current plus the sending-end shunt current).
    C_rp, C_rq : (n_nodes-1,) ndarray
        sum_k 2 R_k |I_m,k| d|I_m,k|/dP_l (resp. dQ_l), i.e. the sensitivity
        of the series losses sum_k R_k |I_m,k|^2, where I_m is the series
        (longitudinal) line current.
    C_xp, C_xq : (n_nodes-1,) ndarray
        Same with X_k, plus sum_i 2 Bn_i |E_i| d|E_i|/dP_l (resp. dQ_l) over
        the non-slack nodes i.
    """
    YY_full = network['YY']
    YL = np.diag(network['YL'])  # Series admittance of each line
    y_lt = np.diag(network['YLT'])  # Sending-end shunt admittance of each line
    line_data = network['linedata']
    Rl = network['R']; Xl = network['X']; Bn = np.asarray(network['Bn'])
    incidence = network['A']

    E_full = np.asarray(E)
    # The slack is excluded from the unknowns but still contributes to the
    # nodal currents through YV.
    YV = (YY_full @ E_full)[1:]
    E_ns = E_full[1:]
    YY = YY_full[1:, 1:]
    n = len(E_ns)

    # --- Voltage sensitivities ---------------------------------------------
    # Differentiating conj(S_i) = conj(E_i) * (Y E)_i w.r.t. an injection:
    #   conj(dE_i) * YV_i + conj(E_i) * sum_l Y_il dE_l = conj(dS_i),
    # split into real and imaginary equations in Re(dE) and Im(dE).
    VY = np.conj(E_ns)[:, None] * YY
    M = np.block([
        [VY.real + np.diag(YV.real), -VY.imag + np.diag(YV.imag)],
        [VY.imag + np.diag(YV.imag), VY.real - np.diag(YV.real)],
    ])
    eye = np.eye(n)
    zero = np.zeros((n, n))
    # Columns 0..n-1: unit P injection (conj(dS) = 1);
    # columns n..2n-1: unit Q injection (conj(dS) = -j).
    rhs = np.block([[eye, zero], [zero, -eye]])
    sol = np.linalg.solve(M, rhs)
    K_V_P = sol[:n, :n] + 1j * sol[n:, :n]
    K_V_Q = sol[:n, n:] + 1j * sol[n:, n:]
    K_V_P_mag = _magnitude_sensitivity(E_ns, K_V_P)
    K_V_Q_mag = _magnitude_sensitivity(E_ns, K_V_Q)

    # --- Current sensitivities ---------------------------------------------
    # Prepend a zero row for the slack node, whose voltage does not change.
    zero_row = np.zeros((1, n))
    K_V_P_full = np.vstack((zero_row, K_V_P))
    K_V_Q_full = np.vstack((zero_row, K_V_Q))
    start = line_data[:, 0].astype(int) - 1  # 0-indexed sending nodes
    end = line_data[:, 1].astype(int) - 1  # 0-indexed receiving nodes

    I_m = YL * (incidence @ E_full)  # Series (longitudinal) current
    I = I_m + y_lt * (np.clip(incidence, 0, 1) @ E_full)  # + sending-end shunt

    K_I_P_m = YL[:, None] * (K_V_P_full[start] - K_V_P_full[end])
    K_I_Q_m = YL[:, None] * (K_V_Q_full[start] - K_V_Q_full[end])
    K_I_P = K_I_P_m + y_lt[:, None] * K_V_P_full[start]
    K_I_Q = K_I_Q_m + y_lt[:, None] * K_V_Q_full[start]

    K_I_P_mag = _magnitude_sensitivity(I, K_I_P)
    K_I_Q_mag = _magnitude_sensitivity(I, K_I_Q)
    # Series-current magnitude sensitivities, used for the loss coefficients
    K_I_P_mag_m = _magnitude_sensitivity(I_m, K_I_P_m)
    K_I_Q_mag_m = _magnitude_sensitivity(I_m, K_I_Q_m)

    # --- Slack power sensitivities -----------------------------------------
    # No shunt conductance, so only series resistance contributes to C_r*.
    w_r = 2 * Rl * np.abs(I_m)
    w_x = 2 * Xl * np.abs(I_m)
    # Bn includes the slack (index 0); align it with the non-slack voltages.
    # NOTE(review): the shunt term enters with a '+' sign. With the nodal shunt
    # admittance j*Bn used in YY, a shunt absorbs -Bn|E|^2 of reactive power,
    # whose derivative is -2 Bn |E| d|E| — confirm the intended convention.
    w_b = 2 * Bn[1:] * np.abs(E_ns)
    C_rp = w_r @ K_I_P_mag_m
    C_rq = w_r @ K_I_Q_mag_m
    C_xp = w_x @ K_I_P_mag_m + w_b @ K_V_P_mag
    C_xq = w_x @ K_I_Q_mag_m + w_b @ K_V_Q_mag

    return K_V_P_mag, K_V_Q_mag, K_I_P_mag, K_I_Q_mag, C_rp, C_rq, C_xp, C_xq


def loadflow(Y, S_star, E_star, E_0, idx, Parameters):
    """Solve the power flow with Newton-Raphson in rectangular coordinates.

    The unknowns are the real and imaginary parts of every non-slack nodal
    voltage. The equations are:
    - active power at PQ and PV nodes;
    - reactive power at PQ nodes;
    - squared voltage magnitude at PV nodes.

    Parameters
    ----------
    Y : (n, n) complex ndarray
        Nodal admittance matrix.
    S_star : (n,) complex array
        Target nodal power injections. Only the entries of the retained
        equations are used.
    E_star : (n,) complex ndarray
        Target voltages: the full phasor at the slack node and the magnitude
        at PV nodes.
    E_0 : (n,) complex array
        Initial guess. It is copied and never modified.
    idx : dict
        Node indices under the keys 'slack', 'pq' and 'pv'. PQ and PV nodes
        together must be exactly the non-slack nodes.
    Parameters : dict
        'n_max' (the loop runs at most n_max - 1 iterations) and 'tol'
        (threshold on the largest absolute mismatch).

    Returns
    -------
    S : (n,) complex ndarray
        Nodal power injections at the solution.
    E : (n,) complex ndarray
        Nodal voltages, rotated so the slack node has the angle of E_star.
    J : ndarray or None
        Last Jacobian built (None if no Jacobian was built).
    n_iter : int
        Index of the last iteration executed (0 if the loop did not run).
    """
    # NOTE: the original loop-based Jacobian was validated against the course's
    # MATLAB code; the vectorised blocks below are algebraically identical.
    E_star = np.asarray(E_star)
    n_nodes = len(E_0)
    G = np.real(Y)
    B = np.imag(Y)

    slack = np.atleast_1d(np.asarray(idx['slack'])).astype(int)
    pq = np.atleast_1d(np.asarray(idx['pq'])).astype(int)
    pv = np.atleast_1d(np.asarray(idx['pv'])).astype(int)
    all_nodes = np.arange(n_nodes)
    free = np.delete(all_nodes, slack)  # Nodes whose voltage is unknown
    p_rows = free  # P equations: every non-slack node
    q_rows = np.delete(all_nodes, np.concatenate((pv, slack)))  # Q equations: PQ nodes
    v_rows = np.delete(all_nodes, np.concatenate((pq, slack)))  # |E|^2 equations: PV nodes
    n_free = len(free)

    # Work on a copy so the caller's initial guess is left untouched. During
    # the iterations the slack voltage is fixed to its magnitude (angle 0);
    # the slack angle is applied to the whole solution at the end, so it is
    # not counted twice.
    E_init = np.array(E_0, dtype=complex)
    E_init[slack] = np.abs(E_star[slack])
    Ere = E_init.real.copy()
    Eim = E_init.imag.copy()
    V2_star = np.abs(E_star) ** 2
    J = None
    n_iter = 0

    for n_iter in range(1, Parameters['n_max']):
        # Nodal voltages/currents/powers
        E = Ere + 1j * Eim
        I = Y @ E
        S = E * np.conj(I)

        # Mismatches of the retained equations
        dS = S_star - S
        dV2 = V2_star - np.abs(E) ** 2
        dF = np.concatenate((np.real(dS)[p_rows], np.imag(dS)[q_rows], dV2[v_rows]))

        # Convergence check
        if np.max(np.abs(dF)) < Parameters['tol']:
            break
        if n_iter == Parameters['n_max'] - 1:
            print('NR algorithm reached the maximum number of iterations!')

        # Full-size Jacobian blocks. With I = Y E = I_re + j I_im, e = Re(E),
        # f = Im(E):  P_i = e_i I_re,i + f_i I_im,i,  Q_i = f_i I_re,i - e_i I_im,i
        e = Ere[:, None]
        f = Eim[:, None]
        J_PR = e * G + f * B + np.diag(I.real)  # dP/de
        J_PX = f * G - e * B + np.diag(I.imag)  # dP/df
        J_QR = f * G - e * B - np.diag(I.imag)  # dQ/de
        J_QX = -e * G - f * B + np.diag(I.real)  # dQ/df
        J_ER = np.diag(2 * Ere)  # d|E|^2/de
        J_EX = np.diag(2 * Eim)  # d|E|^2/df

        # Keep only the retained equations (rows) and unknowns (columns)
        J = np.block([
            [J_PR[np.ix_(p_rows, free)], J_PX[np.ix_(p_rows, free)]],
            [J_QR[np.ix_(q_rows, free)], J_QX[np.ix_(q_rows, free)]],
            [J_ER[np.ix_(v_rows, free)], J_EX[np.ix_(v_rows, free)]],
        ])

        # Solve and update the non-slack voltages
        dx = np.linalg.solve(J, dF)
        Ere[free] += dx[:n_free]
        Eim[free] += dx[n_free:]

    # Rotate the solution so the slack node takes the angle of E_star
    E = Ere + 1j * Eim
    slack_bus_angle = np.angle(E_star[slack])
    E = np.abs(E) * np.exp(1j * (np.angle(E) + slack_bus_angle))
    I = Y @ E
    S = E * np.conj(I)
    return S, E, J, n_iter